jaxlib==0.6.0
jax[cuda12]==0.6.0
optax==0.2.4
flax==0.10.6
ott-jax==0.5.0
flashbax==0.1.3
seaborn==0.13.2
tqdm==4.67.1
ipykernel
notebook